{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "afbb588f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pypose as pp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21ff76e3",
   "metadata": {},
   "source": [
    "# 1. Intialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77de72ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a: so3Type LieTensor:\n",
      "tensor([[ 0.6358, -0.1589, -0.9093],\n",
      "        [-0.4168, -0.1306, -1.4639]]) \n",
      "x.shape: torch.Size([2, 1, 7]) \n",
      "x.gshape: torch.Size([2, 1])\n",
      "SE3Type LieTensor:\n",
      "tensor([[0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 1.]])\n",
      "se3Type LieTensor:\n",
      "tensor([[[ 0.8726,  0.3301, -0.9665,  0.2766, -1.2635,  0.4250],\n",
      "         [-1.4630, -0.6028,  0.3675,  0.5606,  0.2667,  1.4805]],\n",
      "\n",
      "        [[ 0.9815, -0.3840,  0.3326,  1.7540, -1.4837,  0.8004],\n",
      "         [ 0.1170, -1.9541, -0.9162,  0.6001, -1.6201,  0.1619]]])\n"
     ]
    }
   ],
   "source": [
    "a = pp.so3(torch.randn(2,3))\n",
    "x = pp.identity_SE3(2,1)\n",
    "y = pp.randn_se3(2,2)\n",
    "print('a:', a, '\\nx.shape:', x.shape, '\\nx.gshape:', x.lshape)\n",
    "print(x.lview(2))\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cff3249",
   "metadata": {},
   "source": [
    "### All arguments in PyTorch are supported"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3003efc0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(SO3Type LieTensor:\n",
       " tensor([[-0.1433,  0.4112,  0.1737,  0.8833],\n",
       "         [ 0.2252,  0.1924, -0.5417,  0.7866],\n",
       "         [ 0.7282,  0.2878, -0.3736,  0.4973]], device='cuda:0',\n",
       "        dtype=torch.float64, requires_grad=True),\n",
       " SO3Type LieTensor:\n",
       " tensor([[-0.1433,  0.4112,  0.1737,  0.8833],\n",
       "         [ 0.2252,  0.1924, -0.5417,  0.7866],\n",
       "         [ 0.7282,  0.2878, -0.3736,  0.4973]], device='cuda:0',\n",
       "        grad_fn=<AliasBackward0>))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = pp.randn_SO3(3, device=\"cuda:0\", dtype=torch.double, requires_grad=True)\n",
    "b = pp.identity_like(a, device=\"cpu\")\n",
    "a, b\n",
    "t = a.float()\n",
    "a, t"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "minute-archives",
   "metadata": {},
   "source": [
    "# 2. Slicing and Shaping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "welsh-assembly",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A: torch.Size([2, 2])\n",
      "B: torch.Size([2, 1])\n",
      "C: torch.Size([2, 3])\n",
      "D: torch.Size([3])\n",
      "E: torch.Size([2, 1])\n",
      "F: torch.Size([2, 2])\n"
     ]
    }
   ],
   "source": [
    "A = pp.randn_SO3(2,2)\n",
    "B = pp.randn_SO3(2,1)\n",
    "C = torch.cat([A,B], dim=1)         # Tensor cat\n",
    "C[0,1] = pp.randn_SO3(1)            # Slicing set\n",
    "D = C[1,:].Log()                    # Slicing get\n",
    "E, F = torch.split(C, [1,2], dim=1) # Tensor split\n",
    "print('A:', A.lshape)\n",
    "print('B:', B.lshape)\n",
    "print('C:', C.lshape)\n",
    "print('D:', D.lshape)\n",
    "print('E:', E.lshape)\n",
    "print('F:', F.lshape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7ea61f8",
   "metadata": {},
   "source": [
    "# 3. Basic Operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b6927dac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "se3Type LieTensor:\n",
       "tensor([[[-0.8726, -0.3301,  0.9665, -0.2766,  1.2635, -0.4250],\n",
       "         [ 1.4630,  0.6028, -0.3675, -0.5606, -0.2667, -1.4805]],\n",
       "\n",
       "        [[-0.9815,  0.3840, -0.3326, -1.7540,  1.4837, -0.8004],\n",
       "         [-0.1170,  1.9541,  0.9162, -0.6001,  1.6201, -0.1619]]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(x * y.Exp()).Inv().Log()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f94d83b4",
   "metadata": {},
   "source": [
    "# 4. Adjoint Transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "23028575",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(True)\n"
     ]
    }
   ],
   "source": [
    "X = pp.randn_Sim3(6, dtype=torch.double)\n",
    "a = pp.randn_sim3(6, dtype=torch.double)\n",
    "b = X.AdjT(a)\n",
    "print((X * b.Exp() - a.Exp() * X).abs().mean() < 1e-7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c8a2cca5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(True)\n"
     ]
    }
   ],
   "source": [
    "X = pp.randn_SE3(8)\n",
    "a = pp.randn_se3(8)\n",
    "b = X.Adj(a)\n",
    "print((b.Exp() * X - X * a.Exp()).abs().mean() < 1e-7)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fef005df",
   "metadata": {},
   "source": [
    "# 5. Gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "331ff3f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.1383, device='cuda:0', grad_fn=<SumBackward0>),\n",
       " tensor([[-0.3501, -0.0821,  0.2058],\n",
       "         [ 0.0680,  0.2020, -0.3399],\n",
       "         [ 0.1391, -0.2679,  0.3593]], device='cuda:0'),\n",
       " so3Type LieTensor:\n",
       " tensor([[-0.1751, -0.0410,  0.1029],\n",
       "         [ 0.0340,  0.1010, -0.1700],\n",
       "         [ 0.0695, -0.1340,  0.1798]], device='cuda:0', requires_grad=True),\n",
       " so3Type LieTensor:\n",
       " tensor([[-0.1751, -0.0410,  0.1029],\n",
       "         [ 0.0340,  0.1010, -0.1700],\n",
       "         [ 0.0695, -0.1340,  0.1798]], device='cuda:0'))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = pp.randn_so3(3, sigma=0.1, requires_grad=True, device=\"cuda\")\n",
    "assert x.is_leaf\n",
    "loss = (x.Exp().Log()**2).sin().sum() # Just test, No physical meaning\n",
    "loss.backward()\n",
    "y = x.detach()\n",
    "loss, x.grad, x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff7c91bb",
   "metadata": {},
   "source": [
    "# 6. Test a Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0bd984bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before Optimization:\n",
      " so3Type LieTensor:\n",
      "Parameter containing:\n",
      "tensor([[-1.9906,  0.1487, -0.9390],\n",
      "        [ 0.2977,  1.7627,  0.6284],\n",
      "        [ 0.6836,  1.1660,  0.4401],\n",
      "        [-0.3382, -1.3075, -0.2497]], device='cuda:0', requires_grad=True)\n",
      "tensor(7.0022, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor(7.2108, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor(6.7905, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor(6.8941, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor(6.7523, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "Parameter: 12\n",
      "After Optimization:\n",
      " so3Type LieTensor:\n",
      "Parameter containing:\n",
      "tensor([[-0.7525, -0.1441, -1.3598],\n",
      "        [ 0.6561,  0.0567, -0.1614],\n",
      "        [ 0.4767, -0.3290, -0.7701],\n",
      "        [-0.8772, -1.9518,  0.2295]], device='cuda:0', requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "from torch import nn\n",
    "\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "class TestNet(nn.Module):\n",
    "    def __init__(self, n):\n",
    "        super().__init__()\n",
    "        self.weight = pp.Parameter(pp.randn_so3(n))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.weight.Exp() * x\n",
    "\n",
    "\n",
    "n,epoch = 4, 5\n",
    "net = TestNet(n).cuda()\n",
    "\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.2, momentum=0.9)\n",
    "scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4], gamma=0.5)\n",
    "\n",
    "print(\"Before Optimization:\\n\", net.weight)\n",
    "for i in range(epoch):\n",
    "    optimizer.zero_grad()\n",
    "    inputs = pp.randn_SO3(n).cuda()\n",
    "    outputs = net(inputs)\n",
    "    loss = outputs.abs().sum()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    print(loss)\n",
    "\n",
    "print(\"Parameter:\", count_parameters(net))\n",
    "print(\"After Optimization:\\n\", net.weight)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
